{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VYNA79KmgvbY"
      },
      "source": [
        "Copyright 2018 The Dopamine Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "emUEZEvldNyX"
      },
      "source": [
        "# Dopamine: How to create and train a custom agent\n",
        "\n",
        "This colab demonstrates how to create a variant of a provided agent (Example 1) and how to create a new agent from\n",
        "scratch (Example 2).\n",
        "\n",
        "Run all the cells below in order."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8QrqbfDJX0JZ"
      },
      "source": [
        "When running the following cell (\"Install necessary packages\") you will see a notice indicating that you need to restart your runtime. Hit restart, and then continue running the cells below it (from \"Necessary imports and globals\") onward."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Ckq6WG-seC7F"
      },
      "outputs": [],
      "source": [
        "# @title Install necessary packages.\n",
        "!pip install -U dopamine-rl\n",
        "!pip install pandas==0.24.2  # Needed to be able to load the pickle files.\n",
        "!pip install --upgrade gym\n",
        "!pip install gym[atari,accept-rom-license]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "WzwZoRKxdFov"
      },
      "outputs": [],
      "source": [
        "# @title Necessary imports and globals.\n",
        "\n",
        "import numpy as np\n",
        "import os\n",
        "from dopamine.agents.dqn import dqn_agent\n",
        "from dopamine.discrete_domains import run_experiment\n",
        "from dopamine.colab import utils as colab_utils\n",
        "from absl import flags\n",
        "import gin\n",
        "\n",
        "BASE_PATH = '/tmp/colab_dope_run'  # @param\n",
        "GAME = 'Asterix'  # @param"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "EFY3tTITHugq"
      },
      "outputs": [],
      "source": [
        "# @title Load baseline data\n",
        "!gsutil -q -m cp -R gs://download-dopamine-rl/preprocessed-benchmarks/* /content/\n",
        "experimental_data = colab_utils.load_baselines('/content')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bidurBV0djGi"
      },
      "source": [
        "## Example 1: Train a modified version of DQN\n",
        "Asterix is one of the standard agents provided with Dopamine.\n",
        "The purpose of this example is to demonstrate how one can modify an existing agent. The modification\n",
        "we are doing here (choosing actions randomly) is for illustrative purposes: it will clearly perform very\n",
        "poorly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "PUBRSmX6dfa3"
      },
      "outputs": [],
      "source": [
        "# @title Create an agent based on DQN, but choosing actions randomly.\n",
        "\n",
        "LOG_PATH = os.path.join(BASE_PATH, 'random_dqn', GAME)\n",
        "\n",
        "class MyRandomDQNAgent(dqn_agent.DQNAgent):\n",
        "  def __init__(self, sess, num_actions):\n",
        "    \"\"\"This maintains all the DQN default argument values.\"\"\"\n",
        "    super().__init__(sess, num_actions)\n",
        "\n",
        "  def step(self, reward, observation):\n",
        "    \"\"\"Calls the step function of the parent class, but returns a random action.\n",
        "    \"\"\"\n",
        "    super().step(reward, observation)\n",
        "    return np.random.randint(self.num_actions)\n",
        "\n",
        "def create_random_dqn_agent(sess, environment, summary_writer=None):\n",
        "  \"\"\"The Runner class will expect a function of this type to create an agent.\"\"\"\n",
        "  return MyRandomDQNAgent(sess, num_actions=environment.action_space.n)\n",
        "\n",
        "random_dqn_config = \"\"\"\n",
        "import dopamine.discrete_domains.atari_lib\n",
        "import dopamine.discrete_domains.run_experiment\n",
        "atari_lib.create_atari_environment.game_name = '{}'\n",
        "atari_lib.create_atari_environment.sticky_actions = True\n",
        "run_experiment.Runner.num_iterations = 200\n",
        "run_experiment.Runner.training_steps = 10\n",
        "run_experiment.Runner.max_steps_per_episode = 100\n",
        "\"\"\".format(GAME)\n",
        "gin.parse_config(random_dqn_config, skip_unknown=False)\n",
        "\n",
        "# Create the runner class with this agent. We use very small numbers of steps\n",
        "# to terminate quickly, as this is mostly meant for demonstrating how one can\n",
        "# use the framework.\n",
        "random_dqn_runner = run_experiment.TrainRunner(LOG_PATH, create_random_dqn_agent)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "WuWFGwGHfkFp"
      },
      "outputs": [],
      "source": [
        "# @title Train MyRandomDQNAgent.\n",
        "print('Will train agent, please be patient, may be a while...')\n",
        "random_dqn_runner.run_experiment()\n",
        "print('Done training!')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "IknanILXX4Zz"
      },
      "outputs": [],
      "source": [
        "# @title Load the training logs.\n",
        "random_dqn_data = colab_utils.read_experiment(\n",
        "    LOG_PATH, verbose=True, summary_keys=['train_episode_returns'])\n",
        "random_dqn_data['agent'] = 'MyRandomDQN'\n",
        "random_dqn_data['run_number'] = 1\n",
        "experimental_data[GAME] = experimental_data[GAME].merge(random_dqn_data,\n",
        "                                                        how='outer')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "mSOVFUKN-kea"
      },
      "outputs": [],
      "source": [
        "# @title Plot training results.\n",
        "\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(16,8))\n",
        "sns.lineplot(\n",
        "    x='iteration', y='train_episode_returns', hue='agent',\n",
        "    data=experimental_data[GAME], ax=ax)\n",
        "plt.title(GAME)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8T0yfWPw-7QZ"
      },
      "source": [
        "## Example 2: Train an agent built from scratch.\n",
        "The purpose of this example is to demonstrate how one can create an agent from scratch. The agent\n",
        "created here is meant to demonstrate the bare minimum functionality that is expected from agents. It is\n",
        "selecting actions in a very suboptimal way, so it will clearly do poorly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "1kgV__YU-_ET"
      },
      "outputs": [],
      "source": [
        "# @title Create a completely new agent from scratch.\n",
        "\n",
        "LOG_PATH = os.path.join(BASE_PATH, 'sticky_agent', GAME)\n",
        "\n",
        "class StickyAgent(object):\n",
        "  \"\"\"This agent randomly selects an action and sticks to it. It will change\n",
        "  actions with probability switch_prob.\"\"\"\n",
        "  def __init__(self, sess, num_actions, switch_prob=0.1):\n",
        "    self._sess = sess\n",
        "    self._num_actions = num_actions\n",
        "    self._switch_prob = switch_prob\n",
        "    self._last_action = np.random.randint(num_actions)\n",
        "    self.eval_mode = False\n",
        "\n",
        "  def _choose_action(self):\n",
        "    if np.random.random() \u003c= self._switch_prob:\n",
        "      self._last_action = np.random.randint(self._num_actions)\n",
        "    return self._last_action\n",
        "\n",
        "  def bundle_and_checkpoint(self, checkpoint_dir, iteration):\n",
        "    del checkpoint_dir, iteration  # Unused.\n",
        "\n",
        "  def unbundle(self, checkpoint_dir, checkpoint_version, data):\n",
        "    del checkpoint_dir, checkpoint_version, data  # Unused.\n",
        "\n",
        "  def begin_episode(self, observation):\n",
        "    del observation  # Unused.\n",
        "    return self._choose_action()\n",
        "\n",
        "  def end_episode(self, reward):\n",
        "    del reward  # Unused.\n",
        "\n",
        "  def step(self, reward, observation):\n",
        "    return self._choose_action()\n",
        "\n",
        "def create_sticky_agent(sess, environment, summary_writer=None):\n",
        "  \"\"\"The Runner class will expect a function of this type to create an agent.\"\"\"\n",
        "  return StickyAgent(sess, num_actions=environment.action_space.n,\n",
        "                     switch_prob=0.2)\n",
        "\n",
        "sticky_config = \"\"\"\n",
        "import dopamine.discrete_domains.atari_lib\n",
        "import dopamine.discrete_domains.run_experiment\n",
        "atari_lib.create_atari_environment.game_name = '{}'\n",
        "atari_lib.create_atari_environment.sticky_actions = True\n",
        "run_experiment.Runner.num_iterations = 200\n",
        "run_experiment.Runner.training_steps = 10\n",
        "run_experiment.Runner.max_steps_per_episode = 100\n",
        "\"\"\".format(GAME)\n",
        "gin.parse_config(sticky_config, skip_unknown=False)\n",
        "\n",
        "# Create the runner class with this agent. We use very small numbers of steps\n",
        "# to terminate quickly, as this is mostly meant for demonstrating how one can\n",
        "# use the framework.\n",
        "sticky_runner = run_experiment.TrainRunner(LOG_PATH, create_sticky_agent)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "gQt3t_IS_Gku"
      },
      "outputs": [],
      "source": [
        "# @title Train StickyAgent.\n",
        "print('Will train sticky agent, please be patient, may be a while...')\n",
        "sticky_runner.run_experiment()\n",
        "print('Done training!')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "oom0wB0A_Qb8"
      },
      "outputs": [],
      "source": [
        "# @title Load the training logs.\n",
        "sticky_data = colab_utils.read_experiment(\n",
        "    LOG_PATH, verbose=True, summary_keys=['train_episode_returns'])\n",
        "sticky_data['agent'] = 'StickyAgent'\n",
        "sticky_data['run_number'] = 1\n",
        "experimental_data[GAME] = experimental_data[GAME].merge(sticky_data,\n",
        "                                                        how='outer')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "DqsagPbb_Xjm"
      },
      "outputs": [],
      "source": [
        "# @title Plot training results.\n",
        "\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(16,8))\n",
        "sns.lineplot(\n",
        "    x='iteration', y='train_episode_returns', hue='agent',\n",
        "    data=experimental_data[GAME], ax=ax)\n",
        "plt.title(GAME)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "agents.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "/piper/depot/dopamine//colab/agents.ipynb",
          "timestamp": 1655145788020
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
